# dataLoaderETH.py
#
# Sample code for loading and visualizing the additional semantic-segmentation labels
# for the "shapes_translation" sequence in ETHTED.

# See readme.txt for details.
#
# Y.Sekikawa 2018/03/29

import matplotlib.pyplot as pl
import torch
import numpy as np
import os

# Torch dtype of the float-valued tensors returned by eventData.cropAt (events, gt, imu).
dtype = torch.float32
# Torch dtype of the segmentation-label tensor returned by eventData.cropAt.
itype = torch.long


class eventData(object):
    """Event-camera sequence loader that crops fixed-duration windows of events.

    ``dataPath`` is a directory holding ``events.txt`` plus, depending on
    ``labelType``, ``groundtruth.txt``, ``imu.txt`` and ``segmentation.txt``.
    The parsed arrays are cached in ``cashe.npy`` inside that directory, so
    later constructions skip the slow text parsing.

    Column 0 of ``events.txt`` is the timestamp. Columns 1, 2 and 3 are used
    as x, y and polarity by ``cropAt`` and ``plot``.
    """

    # Default boundary between the train and test time ranges, in the unit of
    # the event timestamps (presumably seconds -- TODO confirm in readme.txt).
    split_time = 50.0

    def __init__(self, device, dataPath, W, H, tau, bSort=False, split_time=50.0):
        """Load the sequence stored in ``dataPath``, or resume from its cache.

        Args:
            device: torch device that ``cropAt`` moves its tensors to.
            dataPath: sequence directory, with or without a trailing separator.
            W: plot width; the x-limit used by ``plot``.
            H: plot height; the y-limit used by ``plot``.
            tau: window length, in ticks of 1/temporal_res timestamp units
                (milliseconds if timestamps are in seconds).
            bSort: stored on the instance; not used by this class.
            split_time: timestamp separating the train range from the test range.
        """
        self.tau = tau
        # Timestamps are multiplied by this to get the tick unit that tau is measured in.
        self.temporal_res = 1000
        self.W = W
        self.H = H
        self.bSort = bSort
        self.device = device
        self.split_time = split_time
        # Which label files to load: any subset of 'gt', 'seg', 'imu'.
        self.labelType = ['seg']

        cachePath = os.path.join(dataPath, 'cashe.npy')
        if os.path.exists(cachePath):
            print('Resuming from cashe')
            # The cache is a 1-D object array, which np.load only reads with
            # allow_pickle=True. Unpickling can execute code embedded in the
            # file, so only load caches written by this class.
            # NOTE: the cache does not record labelType; delete it after
            # changing labelType.
            event, gt, imu, seg = np.load(cachePath, allow_pickle=True)
        else:
            print('Loading from txt')
            event = np.loadtxt(os.path.join(dataPath, 'events.txt'))
            gt = np.loadtxt(os.path.join(dataPath, 'groundtruth.txt')) if 'gt' in self.labelType else []
            imu = np.loadtxt(os.path.join(dataPath, 'imu.txt')) if 'imu' in self.labelType else []
            seg = np.loadtxt(os.path.join(dataPath, 'segmentation.txt')) if 'seg' in self.labelType else []
            np.save(cachePath, self._pack([event, gt, imu, seg]))

        self.event, self.gt, self.imu, self.seg = event, gt, imu, seg

    @staticmethod
    def _pack(parts):
        """Store differently-shaped arrays in a 1-D object array so np.save keeps them separate."""
        # Element-wise assignment is used because np.array(parts) would try to
        # broadcast/merge the parts (and fails on ragged input in NumPy >= 1.24).
        packed = np.empty(len(parts), dtype=object)
        for i, part in enumerate(parts):
            packed[i] = part
        return packed

    @staticmethod
    def _first_index_after(times, t):
        """Return the index of the first entry of ``times`` strictly greater than ``t``, or ``len(times)`` if none is."""
        hits = np.flatnonzero(times > t)
        return int(hits[0]) if hits.size else len(times)

    def getTimeRangeTrain(self):
        """Return ``[start, end]`` bounds for window start times in the training part.

        The range runs from the earliest event timestamp to ``split_time``
        minus one window length, so every window ends at or before the split.
        """
        window = self.tau / self.temporal_res
        return [np.min(self.event[:, 0]), self.split_time - window]

    def getTimeRangeTest(self):
        """Return ``[start, end]`` bounds for window start times in the test part.

        The range runs from ``split_time`` to the latest event timestamp
        minus one window length.
        """
        window = self.tau / self.temporal_res
        return [self.split_time, np.max(self.event[:, 0]) - window]

    def cropAt(self, t_s):
        """Crop the events whose timestamps lie in (t_s, t_s + tau/temporal_res].

        The crop is also stored on the instance (``event_``, ``gt_``,
        ``imu_``, ``seg_``, ``t_s``) for use by ``plot``.

        Returns:
            (event, gt, imu, seg) torch tensors on ``self.device``:
            event is (n, 4) float with columns x, y, polarity and time in ticks
            relative to the window's first event. gt / imu hold one (1, k) label
            row selected by ``np.argmax(stamp > t_e)``, or are empty when that
            label type is not loaded. seg is (n, 1) long, or empty when not loaded.
        """
        t_e = t_s + self.tau / self.temporal_res
        times = self.event[:, 0]
        # Slicing between these indices assumes timestamps are sorted ascending.
        i_s = self._first_index_after(times, t_s)
        # Exclusive end: one past the last event with timestamp <= t_e, or the
        # end of the data when the window runs past the last event.
        i_e = self._first_index_after(times, t_e)
        self.t_s = t_s

        window = self.event[i_s:i_e]
        t = np.round(window[:, 0] * self.temporal_res)
        t_rel = t - t[0] if t.size else t  # an empty window has no first event

        self.event_ = np.stack((window[:, 1], window[:, 2], window[:, 3], t_rel), axis=1)
        # NOTE: np.argmax returns 0 when no stamp exceeds t_e, i.e. the first row.
        self.gt_ = (self.gt[np.argmax(self.gt[:, 0] > t_e), :]).reshape([1, -1]) if 'gt' in self.labelType else []
        self.imu_ = (self.imu[np.argmax(self.imu[:, 0] > t_e), :]).reshape([1, -1]) if 'imu' in self.labelType else []
        self.seg_ = self.seg[i_s:i_e].reshape([-1, 1]) if 'seg' in self.labelType else []

        return torch.tensor(self.event_).type(dtype).to(self.device), \
               torch.tensor(self.gt_).type(dtype).to(self.device), \
               torch.tensor(self.imu_).type(dtype).to(self.device), \
               torch.tensor(self.seg_).type(itype).to(self.device)

    @staticmethod
    def _scatter_faded(ax, x, y, t, mask, rgb):
        """Scatter the masked events in colour ``rgb``, scaled by their normalised time ``t``."""
        colors = np.multiply(t[mask].reshape([-1, 1]), np.array(rgb).reshape([1, 3]))
        ax.scatter(x[mask], y[mask], color=colors, s=1)

    def _format_axes(self, ax, title):
        """Apply the shared panel layout: W x H extent, y = 0 at the top, no axis decorations."""
        ax.set_xlim(0, self.W)
        # Pass the limits in (H, 0) order so y grows downward.
        # invert_yaxis() followed by set_ylim(0, H) would silently undo the inversion.
        ax.set_ylim(self.H, 0)
        ax.set_aspect('equal')
        ax.axis('off')
        ax.set_title(title)

    def plot(self):
        """Show the window produced by the most recent ``cropAt`` call.

        Left panel: events coloured by polarity (green for p == 1, red for
        p == 0). Right panel, only when 'seg' labels are loaded: blue for
        label 2 and white for labels <= 1. Brightness scales with the event's
        relative time, so earlier events in the window are darker. Also prints
        len(window) / tau followed by 'KEPS'.
        """
        col_r = [1.0, 0.0, 0.0]
        col_g = [0.0, 1.0, 0.0]
        col_b = [0.0, 0.0, 1.0]
        col_w = [1.0, 1.0, 1.0]

        x = self.event_[:, 0]
        y = self.event_[:, 1]
        p = self.event_[:, 2]
        # Relative time normalised by tau. Clip so the RGB factors stay inside
        # matplotlib's required [0, 1] range.
        t = np.clip(self.event_[:, 3] / self.tau, 0.0, 1.0)
        print(str(len(t) / self.tau) + 'KEPS')

        pl.figure()
        ax = pl.subplot(1, 2, 1)
        self._scatter_faded(ax, x, y, t, p == 1, col_g)
        self._scatter_faded(ax, x, y, t, p == 0, col_r)
        self._format_axes(ax, 'Polarity, t:' + str(self.t_s))

        if 'seg' in self.labelType:
            s = self.seg_[:, 0]
            ax = pl.subplot(1, 2, 2)
            self._scatter_faded(ax, x, y, t, s == 2, col_b)
            self._scatter_faded(ax, x, y, t, s <= 1, col_w)
            self._format_axes(ax, 'Segmentation, t:' + str(self.t_s))

        pl.show()


if __name__ == '__main__':
    import sys

    # Demo: show 100 random windows from the training time range.
    # The sequence directory may be given as the first command-line argument.
    dataPath = sys.argv[1] if len(sys.argv) > 1 else '/Users/ysekikawa/data/ETHTED/shapes_rotation/'

    W, H, tau = 240, 180, 32.
    dataLoader = eventData('cpu', dataPath, W, H, tau)

    start_time, end_time = dataLoader.getTimeRangeTrain()
    for _ in range(100):
        # Scalar start time (not a 1-element array), so the plot title prints a plain number.
        t = np.random.uniform(start_time, end_time)
        event, gt, imu, seg = dataLoader.cropAt(t)
        dataLoader.plot()
